{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Gradient\n",
    "\n",
    "`Gradient` allows to fine tune and get completions on LLMs with a simple web API.\n",
    "\n",
    "This notebook goes over how to use Langchain with [Gradient](https://gradient.ai/).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chains import LLMChain\n",
    "from langchain.prompts import PromptTemplate\n",
    "from langchain_community.llms import GradientLLM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set the Environment API Key\n",
    "Make sure to get your API key from Gradient AI. You are given $10 in free credits to test and fine-tune different models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from getpass import getpass\n",
    "\n",
    "if not os.environ.get(\"GRADIENT_ACCESS_TOKEN\", None):\n",
    "    # Access token under https://auth.gradient.ai/select-workspace\n",
    "    os.environ[\"GRADIENT_ACCESS_TOKEN\"] = getpass(\"gradient.ai access token:\")\n",
    "if not os.environ.get(\"GRADIENT_WORKSPACE_ID\", None):\n",
    "    # `ID` listed in `$ gradient workspace list`\n",
    "    # also displayed after login at at https://auth.gradient.ai/select-workspace\n",
    "    os.environ[\"GRADIENT_WORKSPACE_ID\"] = getpass(\"gradient.ai workspace id:\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional: Validate your Environment variables ```GRADIENT_ACCESS_TOKEN``` and ```GRADIENT_WORKSPACE_ID``` to get currently deployed models. Using the `gradientai` Python package."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: gradientai in /home/michi/.venv/lib/python3.10/site-packages (1.0.0)\n",
      "Requirement already satisfied: aenum>=3.1.11 in /home/michi/.venv/lib/python3.10/site-packages (from gradientai) (3.1.15)\n",
      "Requirement already satisfied: pydantic<2.0.0,>=1.10.5 in /home/michi/.venv/lib/python3.10/site-packages (from gradientai) (1.10.12)\n",
      "Requirement already satisfied: python-dateutil>=2.8.2 in /home/michi/.venv/lib/python3.10/site-packages (from gradientai) (2.8.2)\n",
      "Requirement already satisfied: urllib3>=1.25.3 in /home/michi/.venv/lib/python3.10/site-packages (from gradientai) (1.26.16)\n",
      "Requirement already satisfied: typing-extensions>=4.2.0 in /home/michi/.venv/lib/python3.10/site-packages (from pydantic<2.0.0,>=1.10.5->gradientai) (4.5.0)\n",
      "Requirement already satisfied: six>=1.5 in /home/michi/.venv/lib/python3.10/site-packages (from python-dateutil>=2.8.2->gradientai) (1.16.0)\n"
     ]
    }
   ],
   "source": [
    "%pip install --upgrade --quiet  gradientai"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "99148c6d-c2a0-4fbe-a4a7-e7c05bdb8a09_base_ml_model\n",
      "f0b97d96-51a8-4040-8b22-7940ee1fa24e_base_ml_model\n",
      "cc2dafce-9e6e-4a23-a918-cad6ba89e42e_base_ml_model\n"
     ]
    }
   ],
   "source": [
    "import gradientai\n",
    "\n",
    "client = gradientai.Gradient()\n",
    "\n",
    "models = client.list_models(only_base=True)\n",
    "for model in models:\n",
    "    print(model.id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('674119b5-f19e-4856-add2-767ae7f7d7ef_model_adapter', 'my_model_adapter')"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "new_model = models[-1].create_model_adapter(name=\"my_model_adapter\")\n",
    "new_model.id, new_model.name"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create the Gradient instance\n",
    "You can specify different parameters such as the model, max_tokens generated, temperature, etc.\n",
    "\n",
    "As we later want to fine-tune out model, we select the model_adapter with the id `674119b5-f19e-4856-add2-767ae7f7d7ef_model_adapter`, but you can use any base or fine-tunable model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = GradientLLM(\n",
    "    # `ID` listed in `$ gradient model list`\n",
    "    model=\"674119b5-f19e-4856-add2-767ae7f7d7ef_model_adapter\",\n",
    "    # # optional: set new credentials, they default to environment variables\n",
    "    # gradient_workspace_id=os.environ[\"GRADIENT_WORKSPACE_ID\"],\n",
    "    # gradient_access_token=os.environ[\"GRADIENT_ACCESS_TOKEN\"],\n",
    "    model_kwargs=dict(max_generated_token_count=128),\n",
    ")"
   ]
  },
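  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can sanity-check the LLM directly before wiring it into a chain. This is a minimal sketch: it assumes a LangChain version where LLMs implement the Runnable `invoke` interface, and the prompt text is purely illustrative (the raw completion will vary per run)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick sanity check: query the model adapter directly, outside any chain.\n",
    "# The prompt here is illustrative; the completion will vary between runs.\n",
    "llm.invoke(\"What is the capital of France?\\n\\nAnswer: \")"
   ]
  },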
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a Prompt Template\n",
    "We will create a prompt template for Question and Answer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"\"\"Question: {question}\n",
    "\n",
    "Answer: \"\"\"\n",
    "\n",
    "prompt = PromptTemplate.from_template(template)"
   ]
  },
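  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see exactly what the chain will send to the model, you can render the template yourself. This is a small illustrative check, not required for the chain to work:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the fully rendered prompt for a sample question.\n",
    "print(prompt.format(question=\"What NFL team won the Super Bowl in 1994?\"))"
   ]
  },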
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initiate the LLMChain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm_chain = LLMChain(prompt=prompt, llm=llm)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the LLMChain\n",
    "Provide a question and run the LLMChain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\nThe San Francisco 49ers won the Super Bowl in 1994.'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "question = \"What NFL team won the Super Bowl in 1994?\"\n",
    "\n",
    "llm_chain.run(question=question)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Improve the results by fine-tuning (optional)\n",
    "Well - that is wrong - the San Francisco 49ers did not win.\n",
    "The correct answer to the question would be `The Dallas Cowboys!`.\n",
    "\n",
    "Let's increase the odds for the correct answer, by fine-tuning on the correct answer using the PromptTemplate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'inputs': 'Question: What NFL team won the Super Bowl in 1994?\\n\\nAnswer:  The Dallas Cowboys!'}]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset = [\n",
    "    {\n",
    "        \"inputs\": template.format(question=\"What NFL team won the Super Bowl in 1994?\")\n",
    "        + \" The Dallas Cowboys!\"\n",
    "    }\n",
    "]\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FineTuneResponse(number_of_trainable_tokens=27, sum_loss=78.17996)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "new_model.fine_tune(samples=dataset)"
   ]
  },
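  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A single pass over one sample may not be enough to change the model's answer. As a sketch, you can simply call `fine_tune` again to run additional epochs over the same samples; the epoch count below is an arbitrary illustrative choice, not a recommendation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optionally repeat fine-tuning for a few extra epochs on the same samples.\n",
    "# The number of epochs here is an arbitrary illustrative choice.\n",
    "for _ in range(2):\n",
    "    new_model.fine_tune(samples=dataset)"
   ]
  },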
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'The Dallas Cowboys'"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# we can keep the llm_chain, as the registered model just got refreshed on the gradient.ai servers.\n",
    "llm_chain.run(question=question)"
   ]
  },
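  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, you can clean up the example adapter created above so it does not linger in your workspace. This sketch assumes the `gradientai` SDK's `delete()` and `close()` helpers; skip it if you want to keep the fine-tuned adapter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional cleanup: delete the example model adapter and close the client.\n",
    "new_model.delete()\n",
    "client.close()"
   ]
  }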
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "a0a0263b650d907a3bfe41c0f8d6a63a071b884df3cfdc1579f00cdc1aed6b03"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
